{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demonstration of MCE IRL code & environments\n",
    "\n",
    "This notebook covers only tabular environments and vanilla maximum causal entropy (MCE) IRL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy\n",
    "import torch as th\n",
    "\n",
    "import imitation.algorithms.tabular_irl as tirl\n",
    "import imitation.envs.examples.model_envs as menv\n",
    "\n",
    "sns.set(context='notebook')\n",
    "\n",
    "# Seed every RNG this notebook touches so that Restart & Run All is reproducible.\n",
    "# Seeding NumPy alone is not enough: the reward models below are optimised with\n",
    "# torch.optim.Adam, and PyTorch keeps its own random generator.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "th.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IRL on a random MDP\n",
    "\n",
    "Testing both linear reward models & MLP reward models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small random tabular MDP (generator_seed presumably fixes the sampled MDP -- confirm in menv).\n",
    "mdp = menv.RandomMDP(\n",
    "    n_states=16, n_actions=3, branch_factor=2, horizon=10,\n",
    "    random_obs=True, obs_dim=5, generator_seed=42,\n",
    ")\n",
    "\n",
    "# Demonstrator side: policy from mce_partition_fh, then its occupancy measures.\n",
    "V, Q, pi = tirl.mce_partition_fh(mdp)\n",
    "Dt, D = tirl.mce_occupancy_measures(mdp, pi=pi)\n",
    "\n",
    "# Project D through the observation matrix. The single-element unpack below\n",
    "# requires the result to be 1-D; its length is what the reward model\n",
    "# constructors in the following cells receive.\n",
    "demo_counts = D @ mdp.observation_matrix\n",
    "(obs_dim,) = demo_counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear reward model, trained by mce_irl against the demonstrator occupancy D.\n",
    "rmodel = tirl.LinearRewardModel(obs_dim)\n",
    "opt = th.optim.Adam(rmodel.parameters(), lr=0.1)\n",
    "# NOTE(review): linf_eps looks like a stopping tolerance -- confirm in tirl.mce_irl.\n",
    "D_fake = tirl.mce_irl(\n",
    "    mdp,\n",
    "    opt,\n",
    "    rmodel,\n",
    "    D,\n",
    "    linf_eps=1e-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLP reward model on the same random MDP; [32, 32] is presumably the\n",
    "# hidden-layer spec (confirm against tirl.MLPRewardModel).\n",
    "rmodel = tirl.MLPRewardModel(obs_dim, [32, 32])\n",
    "opt = th.optim.Adam(rmodel.parameters(), lr=0.1)\n",
    "# Note: this overwrites the D_fake produced by the linear model in the previous cell.\n",
    "D_fake = tirl.mce_irl(\n",
    "    mdp,\n",
    "    opt,\n",
    "    rmodel,\n",
    "    D,\n",
    "    linf_eps=1e-2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Same thing, but on grid world\n",
    "\n",
    "The true reward here is not linear in the reduced feature space (i.e. $(x,y)$ coordinates). Finding an appropriate linear reward is impossible (as I will demonstrate), but an MLP should Just Work(tm)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same experiments as above, but on the CliffWorld grid world with (x, y) observations.\n",
    "# This rebinds mdp, V, Q, pi, Dt, D, demo_counts and obs_dim from the random-MDP section.\n",
    "mdp = menv.CliffWorld(width=7, height=4, horizon=8, use_xy_obs=True)\n",
    "V, Q, pi = tirl.mce_partition_fh(mdp)\n",
    "Dt, D = tirl.mce_occupancy_measures(mdp, pi=pi)\n",
    "demo_counts = D @ mdp.observation_matrix\n",
    "(obs_dim,) = demo_counts.shape\n",
    "\n",
    "# Train a linear reward model against the demonstrator occupancy D.\n",
    "rmodel = tirl.LinearRewardModel(obs_dim)\n",
    "opt = th.optim.Adam(rmodel.parameters(), lr=1.0)\n",
    "D_fake = tirl.mce_irl(mdp, opt, rmodel, D, linf_eps=0.1)\n",
    "\n",
    "# Figure 1: demonstrator occupancy D.\n",
    "mdp.draw_value_vec(D)\n",
    "plt.title(\"Cliff World $p(s)$\")\n",
    "plt.xlabel('x-coord')\n",
    "plt.ylabel('y-coord')\n",
    "plt.show()\n",
    "\n",
    "# Figure 2: occupancy returned by mce_irl for the linear reward model.\n",
    "mdp.draw_value_vec(D_fake)\n",
    "plt.title(\"Occupancy for linear reward function\")\n",
    "plt.show()\n",
    "\n",
    "# Figure 3: reward model output on every row of the observation matrix (left)\n",
    "# next to the environment's own reward (right).\n",
    "obs_tensor = th.as_tensor(mdp.observation_matrix)\n",
    "inferred_reward = rmodel(obs_tensor).detach().numpy()\n",
    "panels = [(\"Inferred reward\", inferred_reward), (\"True reward\", mdp.reward_matrix)]\n",
    "for col, (panel_title, values) in enumerate(panels, start=1):\n",
    "    plt.subplot(1, 2, col)\n",
    "    mdp.draw_value_vec(values)\n",
    "    plt.title(panel_title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLP reward model on the CliffWorld MDP built in the previous cell; [1024] is\n",
    "# presumably the hidden-layer spec (confirm against tirl.MLPRewardModel).\n",
    "rmodel = tirl.MLPRewardModel(obs_dim, [1024], activation=th.nn.ReLU)\n",
    "opt = th.optim.Adam(rmodel.parameters(), lr=1e-3)\n",
    "D_fake_mlp = tirl.mce_irl(mdp, opt, rmodel, D, linf_eps=3e-2, print_interval=250)\n",
    "\n",
    "# Figure 1: occupancy returned by mce_irl for the MLP reward model.\n",
    "mdp.draw_value_vec(D_fake_mlp)\n",
    "plt.title(\"Occupancy for MLP reward function\")\n",
    "plt.show()\n",
    "\n",
    "# Figure 2: reward model output on every row of the observation matrix (left)\n",
    "# next to the environment's own reward (right).\n",
    "obs_tensor = th.as_tensor(mdp.observation_matrix)\n",
    "inferred_reward = rmodel(obs_tensor).detach().numpy()\n",
    "panels = [(\"Inferred reward\", inferred_reward), (\"True reward\", mdp.reward_matrix)]\n",
    "for col, (panel_title, values) in enumerate(panels, start=1):\n",
    "    plt.subplot(1, 2, col)\n",
    "    mdp.draw_value_vec(values)\n",
    "    plt.title(panel_title)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the inferred reward is absolutely nothing like the true reward, but the occupancy measure still (roughly) matches the true occupancy measure."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
